{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "maxlen = 20\n",
    "max_vocab = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = tf.keras.datasets.imdb.get_word_index()\n",
    "word2idx = {k: (v + 4) for k, v in word2idx.items()}\n",
    "word2idx['<PAD>'] = 0\n",
    "word2idx['<START>'] = 1\n",
    "word2idx['<UNK>'] = 2\n",
    "word2idx['<END>'] = 3\n",
    "idx2word = {i: w for w, i in word2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_X, _), (test_X, _) = tf.contrib.keras.datasets.imdb.load_data(num_words = max_vocab, index_from= 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.concatenate([train_X, test_X])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.concatenate((tf.keras.preprocessing.sequence.pad_sequences(\n",
    "                            X, maxlen, truncating='post', padding='post'),\n",
    "                        tf.keras.preprocessing.sequence.pad_sequences(\n",
    "                            X, maxlen, truncating='pre', padding='post')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_input = X[:]\n",
    "Y_output = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], word2idx['<END>'])], 1)\n",
    "X = X[:, 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100000, 19), (100000, 20), (100000, 20))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, Y_input.shape, Y_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import attention_wrapper\n",
    "import basic_decoder\n",
    "import decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, dict_size, learning_rate,\n",
    "                dropout = 0.8):\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_input = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_output = tf.placeholder(tf.int32, [None, None])\n",
    "        self.lambda_coeff = tf.placeholder(tf.float32, shape=())\n",
    "        self.attention_temperature = tf.placeholder(tf.float32, shape=())\n",
    "        \n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y_input, 1, dtype=tf.int32)\n",
    "        \n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        decoder_embedded = tf.nn.embedding_lookup(decoder_embeddings, self.Y_input)\n",
    "        \n",
    "        main = tf.strided_slice(self.Y_input, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], word2idx['<START>']), main], 1)\n",
    "        \n",
    "        for i in range(num_layers):\n",
    "            with tf.variable_scope('encoder_%d'%(i)):\n",
    "                cell_fw = tf.contrib.rnn.LayerNormBasicLSTMCell(size_layer)\n",
    "                cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw, input_keep_prob=dropout)\n",
    "\n",
    "                cell_bw = tf.contrib.rnn.LayerNormBasicLSTMCell(size_layer)\n",
    "                cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw, input_keep_prob=dropout)\n",
    "                \n",
    "                self.enc_output, self.enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,\n",
    "                                                                                  cell_bw,\n",
    "                                                                                  encoder_embedded,\n",
    "                                                                                  self.X_seq_len,\n",
    "                                                                                  dtype=tf.float32)\n",
    "        \n",
    "        self.encoder_state = tf.concat([self.enc_state[0][1], self.enc_state[1][1]], axis=-1)\n",
    "        self.encoder_out = tf.concat(self.enc_output, 2)\n",
    "        self.z_mean = tf.layers.dense(self.encoder_state, size_layer)\n",
    "        self.z_log_sigma = tf.layers.dense(self.encoder_state, size_layer)\n",
    "        \n",
    "        epsilon = tf.random_normal(tf.shape(self.z_log_sigma))\n",
    "        self.z_vector = self.z_mean + tf.exp(self.z_log_sigma)\n",
    "        \n",
    "        dense = tf.layers.Dense(dict_size)\n",
    "        \n",
    "        decoder_cells = []\n",
    "        for i in range(num_layers):\n",
    "            dec_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(2 * size_layer)\n",
    "            dec_cell = tf.contrib.rnn.DropoutWrapper(dec_cell, input_keep_prob=dropout)\n",
    "            decoder_cells.append(dec_cell)\n",
    "        \n",
    "        self.decoder_cells = tf.nn.rnn_cell.MultiRNNCell(decoder_cells)\n",
    "        \n",
    "        attn_mech = attention_wrapper.LuongAttention(2 * size_layer,\n",
    "                                                      self.encoder_out,\n",
    "                                                      memory_sequence_length=self.X_seq_len)\n",
    "        \n",
    "        attn_cell = attention_wrapper.AttentionWrapper(self.decoder_cells, attn_mech,\n",
    "                                                       0.5,\n",
    "                                                       True,\n",
    "                                                       size_layer)\n",
    "        \n",
    "        self.init_state = attn_cell.zero_state(batch_size, tf.float32)\n",
    "        \n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embedded,\n",
    "                                                            sequence_length=self.Y_seq_len,\n",
    "                                                            time_major=False)\n",
    "        \n",
    "        training_decoder = basic_decoder.BasicDecoder(attn_cell,\n",
    "                                                training_helper,\n",
    "                                                initial_state=self.init_state,\n",
    "                                                latent_vector=self.z_vector,\n",
    "                                                output_layer=dense)\n",
    "        self.training_logits, _, _, self.c_kl_batch_train = decoder.dynamic_decode(training_decoder,\n",
    "                                                                       output_time_major=False,\n",
    "                                                                       impute_finished=True,\n",
    "                                                                       maximum_iterations=\n",
    "                                                                       tf.reduce_max(self.Y_seq_len))\n",
    "        self.c_kl_batch_train = tf.div(self.c_kl_batch_train, \n",
    "                                       tf.cast(self.Y_seq_len, dtype=tf.float32))\n",
    "        self.training_logits = self.training_logits.rnn_output\n",
    "        \n",
    "        start_tokens = tf.tile(tf.constant([word2idx['<START>']], dtype=tf.int32), \n",
    "                               [batch_size])\n",
    "        inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,\n",
    "                                                                    start_tokens,\n",
    "                                                                    word2idx['<END>'])\n",
    "        inference_decoder = basic_decoder.BasicDecoder(attn_cell,\n",
    "                                                         inference_helper,\n",
    "                                                         initial_state=self.init_state,\n",
    "                                                         latent_vector=self.z_vector,\n",
    "                                                         output_layer=dense)\n",
    "        self.inference_logits, _, _, _ = decoder.dynamic_decode(inference_decoder,\n",
    "                                                                        output_time_major=False,\n",
    "                                                                        impute_finished=True,\n",
    "                                                                        maximum_iterations=\n",
    "                                                                        tf.reduce_max(self.X_seq_len))\n",
    "        self.logits = self.inference_logits.sample_id\n",
    "        self.kl_loss = -0.5 * tf.reduce_sum(1.0 + 2 * self.z_log_sigma - self.z_mean ** 2 - \n",
    "                             tf.exp(2 * self.z_log_sigma), 1)\n",
    "        self.kl_loss = tf.scalar_mul(self.lambda_coeff, self.kl_loss)\n",
    "        self.context_kl_loss = self.c_kl_batch_train * 0.1\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y_output,\n",
    "                                                     weights = masks)\n",
    "        self.cost = tf.reduce_sum(self.cost + self.kl_loss + self.context_kl_loss)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py:1702: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = VAE(size_layer, num_layers, embedded_size, len(word2idx), learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def word_dropout(x):\n",
    "    is_dropped = np.random.binomial(1, 0.8, x.shape)\n",
    "    fn = np.vectorize(lambda x, k: word2idx['<UNK>'] if (\n",
    "                      k and (x not in range(4))) else x)\n",
    "    return fn(x, is_dropped)\n",
    "\n",
    "def inf_inp(test_strs):\n",
    "    x = [[word2idx.get(w, 2) for w in s.split()] for s in test_strs]\n",
    "    x = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        x, maxlen, truncating='post', padding='post')\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_strings = ['i love this film and i think it is one of the best films',\n",
    "             'this movie is a waste of time and there is no point to watch it']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_input = word_dropout(Y_input[:2])\n",
    "y_output = Y_output[:2]\n",
    "x = X[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[22.949621, array([0., 0.], dtype=float32)]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run([model.cost, model.kl_loss],\n",
    "         feed_dict = {model.X: x, model.Y_input: y_input,\n",
    "                      model.Y_output: y_output,\n",
    "                      model.lambda_coeff: 0})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'hauled wildfire idiom pontente warthog illustration madelein dedede intuition piteously kaczorowski observationally odds sacrilage yack'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r_aug = sess.run(model.logits, feed_dict = {model.X: inf_inp(test_strings)})[0]\n",
    "' '.join([idx2word[r] for r in r_aug])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 32\n",
    "iter_i = 0\n",
    "lambda_val = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:45<00:00,  2.66it/s, cost=2.94]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, average loss 56.320713\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:35<00:00,  2.67it/s, cost=0.426]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, average loss 4.421890\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:35<00:00,  2.67it/s, cost=0.171]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, average loss 2.081798\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:34<00:00,  2.65it/s, cost=0.146]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, average loss 1.430513\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:36<00:00,  2.66it/s, cost=0.831] \n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, average loss 1.257294\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:35<00:00,  2.65it/s, cost=0.616]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, average loss 1.620867\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:35<00:00,  2.67it/s, cost=0.475]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, average loss 0.946291\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films i\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:36<00:00,  2.66it/s, cost=0.306]\n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, average loss 2.231685\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films films\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:36<00:00,  2.65it/s, cost=0.048] \n",
      "minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, average loss 0.607313\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films love\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 3125/3125 [19:35<00:00,  2.64it/s, cost=0.0188]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, average loss 2.848585\n",
      "real string: i love this film and i think it is one of the best films\n",
      "augmented string: i love this film and i think it is one of the best films love\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for e in range(epoch):\n",
    "    pbar = tqdm(\n",
    "        range(0, len(X), batch_size), desc = 'minibatch loop')\n",
    "    cost = 0\n",
    "    for i in pbar:\n",
    "        iter_i += 1\n",
    "        if iter_i <= 3000:\n",
    "            lambda_val = np.round((np.tanh((iter_i - 4500) / 1000) + 1) / 2, decimals=6)\n",
    "            \n",
    "        index = min(i + batch_size, len(X))\n",
    "        y_input = word_dropout(Y_input[i: index])\n",
    "        y_output = Y_output[i: index]\n",
    "        x = X[i: index]\n",
    "        c, _ = sess.run([model.cost, model.optimizer],\n",
    "         feed_dict = {model.X: x, model.Y_input: y_input,\n",
    "                      model.Y_output: y_output,\n",
    "                      model.lambda_coeff: lambda_val})\n",
    "        cost += c\n",
    "        pbar.set_postfix(cost = c)\n",
    "    cost /= (len(X) / batch_size)\n",
    "    r_aug = sess.run(model.logits, feed_dict = {model.X: inf_inp(test_strings)})[0]\n",
    "    print('epoch %d, average loss %f'%(e + 1, cost))\n",
    "    print('real string: %s'%(test_strings[0]))\n",
    "    print('augmented string: %s'%(' '.join([idx2word[r] for r in r_aug])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
